{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Amazon SageMaker Object Detection for Bird Species\n",
    "\n",
    "1. [Introduction](#Introduction)\n",
    "2. [Setup](#Setup)\n",
    "3. [Data Preparation](#Data-Preparation)\n",
    "  1. [Download and unpack the dataset](#Download-and-unpack-the-dataset)\n",
    "  2. [Understand the dataset](#Understand-the-dataset)\n",
    "  3. [Generate RecordIO files](#Generate-RecordIO-files)\n",
    "4. [Train the model](#Train-the-model)\n",
    "5. [Host the model](#Host-the-model)\n",
    "6. [Test the model](#Test-the-model)\n",
    "7. [Clean up](#Clean-up)\n",
    "8. [Improve the model](#Improve-the-model)\n",
    "9. [Final cleanup](#Final-cleanup)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
     "Object detection is the process of identifying and localizing objects in an image. A typical object detection solution takes an image as input and provides a bounding box on the image where each object of interest is found, along with a label identifying what type of object the box encapsulates.  To create such a solution, we need to acquire and process a training dataset, then create and configure a training job so that the algorithm can learn from that dataset. Finally, we can host the trained model in an endpoint, to which we can supply images.\n",
    "\n",
    "This notebook is an end-to-end example showing how the Amazon SageMaker Object Detection algorithm can be used with a publicly available dataset of bird images. We demonstrate how to train and to host an object detection model based on the [Caltech Birds (CUB 200 2011)](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset.  Amazon SageMaker's object detection algorithm uses the Single Shot multibox Detector ([SSD](https://arxiv.org/abs/1512.02325)) algorithm, and this notebook uses a [ResNet](https://arxiv.org/pdf/1603.05027.pdf) base network with that algorithm.\n",
    "\n",
    "![Sample results detecting a pair of goldfinch on a feeder](./goldfinch_detections.png)\n",
    "\n",
    "We will also demonstrate how to construct a training dataset using the RecordIO format, as this is the format that the training job consumes.  This notebook is similar to the [Object Detection using the RecordIO format](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/object_detection_pascalvoc_coco/object_detection_recordio_format.ipynb) notebook, with the following key differences:\n",
    "\n",
    "- We provide an example of how to translate bounding box specifications when providing images to SageMaker's algorithm. You will see code for generating the train.lst and val.lst files used to create [recordIO](https://mxnet.incubator.apache.org/architecture/note_data_loading.html) files.\n",
    "- We demonstrate how to improve an object detection model by adding training images that are flipped horizontally (mirror images).\n",
    "- We give you a notebook for experimenting with object detection challenges with an order of magnitude more classes (200 bird species, as opposed to the 20 categories used by [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/)).\n",
    "- We show how to chart the accuracy improvements that occur across the epochs of the training job.\n",
    "\n",
    "Note that Amazon SageMaker Object Detection also allows training with the image and JSON format, which is illustrated in the [image and JSON Notebook](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/object_detection_pascalvoc_coco/object_detection_image_json_format.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Before preparing the data, there are some initial steps required for setup.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook requires two additional Python packages:\n",
     "* **OpenCV** is required for gathering image sizes and flipping images horizontally.\n",
    "* The **MXNet** runtime is required for using the im2rec tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "!{sys.executable} -m pip install opencv-python\n",
    "!{sys.executable} -m pip install mxnet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "We need to identify the S3 bucket that you want to use for providing training and validation datasets.  It will also be used to store the trained model artifacts. In this notebook, we use a custom bucket. You could alternatively use a default bucket for the session.  We use an object prefix to help organize the bucket content."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bucket = '<your_s3_bucket_name_here>' # custom bucket name.\n",
    "prefix = 'DEMO-ObjectDetection-birds'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "To train the Object Detection algorithm on Amazon SageMaker, we need to set up and authenticate the use of AWS services. To begin with, we need an AWS account role with SageMaker access. Here we use the execution role that the current notebook instance was given when it was created.  This role has the necessary permissions, including access to your data in S3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "role = get_execution_role()\n",
    "print(role)\n",
    "sess = sagemaker.Session()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Preparation\n",
    "\n",
    "The [Caltech Birds (CUB 200 2011)](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset contains 11,788 images across 200 bird species (the original technical report can be found [here](http://www.vision.caltech.edu/visipedia/papers/CUB_200_2011.pdf)).  Each species comes with around 60 images, with a typical size of about 350 pixels by 500 pixels.  Bounding boxes are provided, as are annotations of bird parts.  A recommended train/test split is given, but image size data is not.\n",
    "\n",
    "![](./cub_200_2011_snapshot.png)\n",
    "\n",
    "The dataset can be downloaded [here](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html).\n",
    "\n",
    "## Download and unpack the dataset\n",
    "\n",
     "Here we download the birds dataset from Caltech."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import urllib.request\n",
    "\n",
    "def download(url):\n",
    "    filename = url.split('/')[-1]\n",
    "    if not os.path.exists(filename):\n",
    "        urllib.request.urlretrieve(url, filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "download('http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we unpack the dataset into its own directory structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Clean up prior version of the downloaded dataset if you are running this again\n",
    "!rm -rf CUB_200_2011  \n",
    "\n",
    "# Unpack and then remove the downloaded compressed tar file\n",
    "!gunzip -c ./CUB_200_2011.tgz | tar xopf - \n",
    "!rm CUB_200_2011.tgz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Understand the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set some parameters for the rest of the notebook to use"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Here we define a few parameters that drive the rest of the notebook.  For example, `SAMPLE_ONLY` defaults to `True`, which forces the notebook to train on only a handful of species.  Setting it to `False` makes the notebook work with the entire dataset of 200 bird species.  That makes training a much more difficult challenge, and many more epochs are required to complete it.\n",
    "\n",
    "The file parameters define names and locations of metadata files for the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import cv2\n",
    "import boto3\n",
    "import json\n",
    "\n",
    "runtime = boto3.client(service_name='runtime.sagemaker')\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "RANDOM_SPLIT = False\n",
    "SAMPLE_ONLY  = True\n",
    "FLIP         = False\n",
    "\n",
    "# To speed up training and experimenting, you can use a small handful of species.\n",
    "# To see the full list of the classes available, look at the content of CLASSES_FILE.\n",
    "CLASSES = [17, 36, 47, 68, 73]\n",
    "\n",
    "# Otherwise, you can use the full set of species\n",
    "if (not SAMPLE_ONLY):\n",
    "    CLASSES = []\n",
    "    for c in range(200):\n",
    "        CLASSES += [c + 1]\n",
    "\n",
    "RESIZE_SIZE = 256\n",
    "\n",
    "BASE_DIR   = 'CUB_200_2011/'\n",
    "IMAGES_DIR = BASE_DIR + 'images/'\n",
    "\n",
    "CLASSES_FILE = BASE_DIR + 'classes.txt'\n",
    "BBOX_FILE    = BASE_DIR + 'bounding_boxes.txt'\n",
    "IMAGE_FILE   = BASE_DIR + 'images.txt'\n",
    "LABEL_FILE   = BASE_DIR + 'image_class_labels.txt'\n",
    "SIZE_FILE    = BASE_DIR + 'sizes.txt'\n",
    "SPLIT_FILE   = BASE_DIR + 'train_test_split.txt'\n",
    "\n",
    "TRAIN_LST_FILE = 'birds_ssd_train.lst'\n",
    "VAL_LST_FILE   = 'birds_ssd_val.lst'\n",
    "\n",
    "if (SAMPLE_ONLY):\n",
    "    TRAIN_LST_FILE = 'birds_ssd_sample_train.lst'\n",
    "    VAL_LST_FILE   = 'birds_ssd_sample_val.lst'\n",
    "\n",
    "TRAIN_RATIO     = 0.8\n",
    "CLASS_COLS      = ['class_number','class_id']\n",
    "IM2REC_SSD_COLS = ['header_cols', 'label_width', 'zero_based_id', 'xmin', 'ymin', 'xmax', 'ymax', 'image_file_name']"
   ]
  },
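   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The `FLIP` option above enables a data augmentation step used later in this notebook: adding horizontally flipped (mirror image) copies of the training images.  Flipping an image horizontally mirrors its bounding box around the vertical center line, so the relative coordinates become `1 - xmax` and `1 - xmin`, while the y coordinates are unchanged.  Below is a minimal sketch of that transformation; it uses a small synthetic image rather than a dataset file, so it stands on its own."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "import numpy as np\n",
     "\n",
     "# Synthetic 100 (height) x 200 (width) image standing in for a dataset file\n",
     "img = np.zeros((100, 200, 3), dtype=np.uint8)\n",
     "\n",
     "# Mirror the image around the vertical axis (equivalent to cv2.flip(img, 1))\n",
     "flipped = img[:, ::-1]\n",
     "assert flipped.shape == img.shape\n",
     "\n",
     "# Relative bounding box (xmin, ymin, xmax, ymax) before the flip\n",
     "xmin, ymin, xmax, ymax = 0.2, 0.1, 0.5, 0.9\n",
     "\n",
     "# Mirror the x coordinates; the y coordinates are unaffected\n",
     "flipped_box = (1.0 - xmax, ymin, 1.0 - xmin, ymax)\n",
     "print(flipped_box)  # (0.5, 0.1, 0.8, 0.9)"
    ]
   },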
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explore the dataset images\n",
    "\n",
     "For each species, there are dozens of images of various shapes and sizes. Because the dataset is divided into individual named (numbered) folders, one per species, the images are in effect labelled for supervised learning with image classification and object detection algorithms.\n",
    "\n",
    "The following function displays a grid of thumbnail images for all the image files for a given species."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_species(species_id):\n",
    "    _im_list = !ls $IMAGES_DIR/$species_id\n",
    "\n",
    "    NUM_COLS = 6\n",
    "    IM_COUNT = len(_im_list)\n",
    "\n",
    "    print('Species ' + species_id + ' has ' + str(IM_COUNT) + ' images.')\n",
    "    \n",
    "    NUM_ROWS = int(IM_COUNT / NUM_COLS)\n",
    "    if ((IM_COUNT % NUM_COLS) > 0):\n",
    "        NUM_ROWS += 1\n",
    "\n",
    "    fig, axarr = plt.subplots(NUM_ROWS, NUM_COLS)\n",
    "    fig.set_size_inches(8.0, 16.0, forward=True)\n",
    "\n",
    "    curr_row = 0\n",
    "    for curr_img in range(IM_COUNT):\n",
     "        # build the path to the image file, then read the image\n",
    "        f = IMAGES_DIR + species_id + '/' + _im_list[curr_img]\n",
    "        a = plt.imread(f)\n",
    "\n",
     "        # find the grid position by taking the current index modulo the number of rows\n",
    "        col = curr_img % NUM_ROWS\n",
    "        # plot on relevant subplot\n",
    "        axarr[col, curr_row].imshow(a)\n",
    "        if col == (NUM_ROWS - 1):\n",
    "            # we have finished the current row, so increment row counter\n",
    "            curr_row += 1\n",
    "\n",
    "    fig.tight_layout()       \n",
    "    plt.show()\n",
    "        \n",
    "    # Clean up\n",
    "    plt.clf()\n",
    "    plt.cla()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Show the list of bird species or dataset classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes_df = pd.read_csv(CLASSES_FILE, sep=' ', names=CLASS_COLS, header=None)\n",
    "criteria = classes_df['class_number'].isin(CLASSES)\n",
    "classes_df = classes_df[criteria]\n",
    "print(classes_df.to_csv(columns=['class_id'], sep='\\t', index=False, header=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now for any given species, display thumbnail images of each of the images provided for training and testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_species('017.Cardinal')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Generate RecordIO files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1. Gather image sizes\n",
    "\n",
    "For this particular dataset, bounding box annotations are specified in absolute terms.  RecordIO format requires them to be defined in terms relative to the image size.  The following code visits each image, extracts the height and width, and saves this information into a file for subsequent use.  Some other publicly available datasets provide such a file for exactly this purpose. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "SIZE_COLS = ['idx','width','height']\n",
    "\n",
    "def gen_image_size_file():\n",
    "    print('Generating a file containing image sizes...')\n",
    "    images_df = pd.read_csv(IMAGE_FILE, sep=' ',\n",
    "                            names=['image_pretty_name', 'image_file_name'],\n",
    "                            header=None)\n",
    "    rows_list = []\n",
    "    idx = 0\n",
    "    for i in images_df['image_file_name']:\n",
    "        # TODO: add progress bar\n",
    "        idx += 1\n",
    "        img = cv2.imread(IMAGES_DIR + i)\n",
     "        height, width = img.shape[0], img.shape[1]\n",
    "        image_dict = {'idx': idx, 'width': width, 'height': height}\n",
    "        rows_list.append(image_dict)\n",
    "\n",
    "    sizes_df = pd.DataFrame(rows_list)\n",
    "    print('Image sizes:\\n' + str(sizes_df.head()))\n",
    "\n",
    "    sizes_df[SIZE_COLS].to_csv(SIZE_FILE, sep=' ', index=False, header=None)\n",
    "\n",
    "gen_image_size_file()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2. Generate list files for producing RecordIO files \n",
    "\n",
    "[RecordIO](https://mxnet.incubator.apache.org/architecture/note_data_loading.html) files can be created using the [im2rec tool](https://mxnet.incubator.apache.org/faq/recordio.html) (images to RecordIO), which takes as input a pair of list files, one for training images and the other for validation images.  Each list file has one row for each image.  For object detection, each row must contain bounding box data and a class label.\n",
    "\n",
     "For the Caltech birds dataset, we need to convert absolute bounding box dimensions to relative dimensions based on image size.  We also need to adjust the class IDs to be zero-based (instead of 1 to 200, they need to be 0 to 199).  This dataset comes with recommended train/test split information (the \"is_training_image\" flag).  This notebook is built flexibly to either leverage that suggestion, or to create a random train/test split with a specific train/test ratio.  The `RANDOM_SPLIT` variable defined earlier controls whether or not the split happens randomly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_to_train_test(df, label_column, train_frac=0.8):\n",
    "    train_df, test_df = pd.DataFrame(), pd.DataFrame()\n",
    "    labels = df[label_column].unique()\n",
    "    for lbl in labels:\n",
    "        lbl_df = df[df[label_column] == lbl]\n",
    "        lbl_train_df = lbl_df.sample(frac=train_frac)\n",
    "        lbl_test_df = lbl_df.drop(lbl_train_df.index)\n",
    "        print('\\n{}:\\n---------\\ntotal:{}\\ntrain_df:{}\\ntest_df:{}'.format(lbl, len(lbl_df), len(lbl_train_df), len(lbl_test_df)))\n",
    "        train_df = train_df.append(lbl_train_df)\n",
    "        test_df = test_df.append(lbl_test_df)\n",
    "    return train_df, test_df\n",
    "\n",
    "def gen_list_files():\n",
    "    # use generated sizes file\n",
    "    sizes_df = pd.read_csv(SIZE_FILE, sep=' ',\n",
    "                names=['image_pretty_name', 'width', 'height'],\n",
    "                header=None)\n",
    "    bboxes_df = pd.read_csv(BBOX_FILE, sep=' ',\n",
    "                names=['image_pretty_name', 'x_abs', 'y_abs', 'bbox_width', 'bbox_height'],\n",
    "                header=None)\n",
    "    split_df = pd.read_csv(SPLIT_FILE, sep=' ',\n",
    "                            names=['image_pretty_name', 'is_training_image'],\n",
    "                            header=None)\n",
    "    print(IMAGE_FILE)\n",
    "    images_df = pd.read_csv(IMAGE_FILE, sep=' ',\n",
    "                            names=['image_pretty_name', 'image_file_name'],\n",
    "                            header=None)\n",
    "    print('num images total: ' + str(images_df.shape[0]))\n",
    "    image_class_labels_df = pd.read_csv(LABEL_FILE, sep=' ',\n",
    "                                names=['image_pretty_name', 'class_id'], header=None)\n",
    "\n",
    "    # Merge the metadata into a single flat dataframe for easier processing\n",
    "    full_df = pd.DataFrame(images_df)\n",
    "    full_df.reset_index(inplace=True)\n",
    "    full_df = pd.merge(full_df, image_class_labels_df, on='image_pretty_name')\n",
    "    full_df = pd.merge(full_df, sizes_df, on='image_pretty_name')\n",
    "    full_df = pd.merge(full_df, bboxes_df, on='image_pretty_name')\n",
    "    full_df = pd.merge(full_df, split_df, on='image_pretty_name')\n",
    "    full_df.sort_values(by=['index'], inplace=True)\n",
    "\n",
    "    # Define the bounding boxes in the format required by SageMaker's built in Object Detection algorithm.\n",
    "    # the xmin/ymin/xmax/ymax parameters are specified as ratios to the total image pixel size\n",
    "    full_df['header_cols'] = 2  # one col for the number of header cols, one for the label width\n",
    "    full_df['label_width'] = 5  # number of cols for each label: class, xmin, ymin, xmax, ymax\n",
    "    full_df['xmin'] = full_df['x_abs'] / full_df['width']\n",
    "    full_df['xmax'] = (full_df['x_abs'] + full_df['bbox_width']) / full_df['width']\n",
    "    full_df['ymin'] = full_df['y_abs'] / full_df['height']\n",
    "    full_df['ymax'] = (full_df['y_abs'] + full_df['bbox_height']) / full_df['height']\n",
    "\n",
    "    # object detection class id's must be zero based. map from\n",
    "    # class_id's given by CUB to zero-based (1 is 0, and 200 is 199).\n",
    "\n",
    "    if SAMPLE_ONLY:\n",
    "        # grab a small subset of species for testing\n",
    "        criteria = full_df['class_id'].isin(CLASSES)\n",
    "        full_df = full_df[criteria]\n",
    "\n",
    "    unique_classes = full_df['class_id'].drop_duplicates()\n",
    "    sorted_unique_classes = sorted(unique_classes)\n",
    "\n",
    "    id_to_zero = {}\n",
    "    i = 0.0\n",
    "    for c in sorted_unique_classes:\n",
    "        id_to_zero[c] = i\n",
    "        i += 1.0\n",
    "\n",
    "    full_df['zero_based_id'] = full_df['class_id'].map(id_to_zero)\n",
    "\n",
    "    full_df.reset_index(inplace=True)\n",
    "\n",
    "    # use 4 decimal places, as it seems to be required by the Object Detection algorithm\n",
    "    pd.set_option(\"display.precision\", 4)\n",
    "\n",
    "    train_df = []\n",
    "    val_df = []\n",
    "\n",
    "    if (RANDOM_SPLIT):\n",
    "        # split into training and validation sets\n",
    "        train_df, val_df = split_to_train_test(full_df, 'class_id', TRAIN_RATIO)\n",
    "\n",
    "        train_df[IM2REC_SSD_COLS].to_csv(TRAIN_LST_FILE, sep='\\t',\n",
    "                float_format='%.4f', header=None)\n",
    "        val_df[IM2REC_SSD_COLS].to_csv(  VAL_LST_FILE, sep='\\t',\n",
    "                float_format='%.4f', header=None)\n",
    "    else:\n",
    "        train_df = full_df[(full_df.is_training_image == 1)]\n",
    "        train_df[IM2REC_SSD_COLS].to_csv(TRAIN_LST_FILE, sep='\\t',\n",
    "                float_format='%.4f', header=None)\n",
    "\n",
    "        val_df = full_df[(full_df.is_training_image == 0)]\n",
    "        val_df[IM2REC_SSD_COLS].to_csv(  VAL_LST_FILE, sep='\\t',\n",
    "                float_format='%.4f', header=None)\n",
    "        \n",
    "    print('num train: ' + str(train_df.shape[0]))\n",
    "    print('num val: ' + str(val_df.shape[0]))\n",
    "    return train_df, val_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df, val_df = gen_list_files()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we take a look at a few records from the training list file to understand better what is being fed to the RecordIO files.\n",
    "\n",
     "The first column is the image number or index.  The second column gives the number of header columns, which is 2 (this column and the label-width column).  The third column specifies the label width of a single object. In our case, the value 5 indicates that each object is described by 5 numbers: the class index and the 4 bounding box coordinates. If there are multiple objects within one image, all of the label information should be listed in one line. Our dataset contains only one bounding box per image.\n",
    "\n",
     "The fourth column is the class label, which identifies the bird species using a zero-based class ID.  Columns 5 through 8 define the bounding box for where the bird is found in this image.\n",
    "\n",
    "The classes should be labeled with successive numbers and start with 0. The bounding box coordinates are ratios of its top-left (xmin, ymin) and bottom-right (xmax, ymax) corner indices to the overall image size. Note that the top-left corner of the entire image is the origin (0, 0). The last column specifies the relative path of the image file within the images directory."
   ]
  },
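   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "To make this format concrete, the sketch below builds a single list-file row by hand.  The image size, bounding box, index, and file name here are made-up values for illustration, not taken from the dataset."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Hypothetical image: 500 pixels wide, 375 pixels high, with an\n",
     "# absolute bounding box (x, y, width, height) = (60, 27, 325, 304)\n",
     "img_width, img_height = 500, 375\n",
     "x_abs, y_abs, bbox_width, bbox_height = 60, 27, 325, 304\n",
     "\n",
     "# Convert the absolute box into ratios of the overall image size\n",
     "xmin = x_abs / img_width\n",
     "ymin = y_abs / img_height\n",
     "xmax = (x_abs + bbox_width) / img_width\n",
     "ymax = (y_abs + bbox_height) / img_height\n",
     "\n",
     "# index, header_cols, label_width, zero-based class id, box, file name\n",
     "row = '\\t'.join(['42', '2', '5', '0.0000'] +\n",
     "                ['%.4f' % v for v in (xmin, ymin, xmax, ymax)] +\n",
     "                ['017.Cardinal/example_image.jpg'])\n",
     "print(row)"
    ]
   },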
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tail -3 $TRAIN_LST_FILE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Step 3. Convert data into RecordIO format\n",
    "\n",
    "Now we create im2rec databases (.rec files) for training and validation based on the list files created earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python tools/im2rec.py --resize $RESIZE_SIZE --pack-label birds_ssd_sample $BASE_DIR/images/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Step 4. Upload RecordIO files to S3\n",
    "Upload the training and validation data to the S3 bucket. We do this in multiple channels. Channels are simply directories in the bucket that differentiate the types of data provided to the algorithm. For the object detection algorithm, we call these directories `train` and `validation`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the RecordIO files to train and validation channels\n",
    "train_channel = prefix + '/train'\n",
    "validation_channel = prefix + '/validation'\n",
    "\n",
    "sess.upload_data(path='birds_ssd_sample_train.rec', bucket=bucket, key_prefix=train_channel)\n",
    "sess.upload_data(path='birds_ssd_sample_val.rec', bucket=bucket, key_prefix=validation_channel)\n",
    "\n",
    "s3_train_data = 's3://{}/{}'.format(bucket, train_channel)\n",
    "s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Next we define an output location in S3 where the model artifacts will be placed on completion of the training. These artifacts are the output of the algorithm's training job.  We also get the URI to the Amazon SageMaker Object Detection docker image.  This ensures the estimator uses the correct algorithm from the current region."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "\n",
    "training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version='latest')\n",
     "print(training_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "od_model = sagemaker.estimator.Estimator(training_image,\n",
    "                                         role, \n",
    "                                         train_instance_count=1, \n",
    "                                         train_instance_type='ml.p3.2xlarge',\n",
    "                                         train_volume_size = 50,\n",
    "                                         train_max_run = 360000,\n",
    "                                         input_mode= 'File',\n",
    "                                         output_path=s3_output_location,\n",
    "                                         sagemaker_session=sess)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "The object detection algorithm at its core is the [Single-Shot Multi-Box detection algorithm (SSD)](https://arxiv.org/abs/1512.02325). This algorithm uses a `base_network`, which is typically a [VGG](https://arxiv.org/abs/1409.1556) or a [ResNet](https://arxiv.org/abs/1512.03385). The Amazon SageMaker object detection algorithm supports VGG-16 and ResNet-50. It also has a number of hyperparameters that help configure the training job. The next step in our training is to set up these hyperparameters and data channels for training the model. See the SageMaker Object Detection [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/object-detection.html) for more details on its specific hyperparameters.\n",
    "\n",
    "One of the hyperparameters here for example is `epochs`. This defines how many passes of the dataset we iterate over and drives the training time of the algorithm. Based on our tests, we can achieve 70% accuracy on a sample mix of 5 species with 100 epochs.  When using the full 200 species, we can achieve 52% accuracy with 1,200 epochs.\n",
    "\n",
    "Note that Amazon SageMaker also provides [Automatic Model Tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html).  Automatic model tuning, also known as hyperparameter tuning, finds the best version of a model by running many training jobs on your dataset using the algorithm and ranges of hyperparameters that you specify. It then chooses the hyperparameter values that result in a model that performs the best, as measured by a metric that you choose.  When [tuning an Object Detection](https://docs.aws.amazon.com/sagemaker/latest/dg/object-detection-tuning.html) algorithm for example, the tuning job could find the best `validation:mAP` score by trying out various values for certain hyperparameters such as `mini_batch_size`, `weight_decay`, and `momentum`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_hyperparameters(num_epochs, lr_steps):\n",
    "    num_classes = classes_df.shape[0]\n",
    "    num_training_samples = train_df.shape[0]\n",
    "    print('num classes: {}, num training images: {}'.format(num_classes, num_training_samples))\n",
    "\n",
    "    od_model.set_hyperparameters(base_network='resnet-50',\n",
    "                                 use_pretrained_model=1,\n",
    "                                 num_classes=num_classes,\n",
    "                                 mini_batch_size=16,\n",
    "                                 epochs=num_epochs,               \n",
    "                                 learning_rate=0.001, \n",
    "                                 lr_scheduler_step=lr_steps,      \n",
    "                                 lr_scheduler_factor=0.1,\n",
    "                                 optimizer='sgd',\n",
    "                                 momentum=0.9,\n",
    "                                 weight_decay=0.0005,\n",
    "                                 overlap_threshold=0.5,\n",
    "                                 nms_threshold=0.45,\n",
    "                                 image_shape=512,\n",
    "                                 label_width=350,\n",
    "                                 num_training_samples=num_training_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_hyperparameters(100, '33,67')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Now that the hyperparameters are set up, we define the data channels to be passed to the algorithm. To do this, we create `sagemaker.session.s3_input` objects from our data channels. These objects are then put in a simple dictionary, which the algorithm consumes.  Note that you could add a third channel named `model` to perform incremental training (continue training from where you had left off with a prior model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = sagemaker.session.s3_input(s3_train_data, distribution='FullyReplicated', \n",
    "                        content_type='application/x-recordio', s3_data_type='S3Prefix')\n",
    "validation_data = sagemaker.session.s3_input(s3_validation_data, distribution='FullyReplicated', \n",
    "                             content_type='application/x-recordio', s3_data_type='S3Prefix')\n",
    "data_channels = {'train': train_data, 'validation': validation_data}"
   ]
  },
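   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "As an alternative to hand-picking hyperparameter values, the Automatic Model Tuning feature mentioned earlier could be pointed at these same data channels.  The sketch below shows one way that might look with the SageMaker Python SDK.  The hyperparameter ranges and job counts are illustrative guesses, not recommendations, and running the tuner would launch real (billable) training jobs, so the `fit` call is left commented out and this cell is not part of the notebook's main flow."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Illustrative sketch only -- running the tuner launches billable training jobs\n",
     "from sagemaker.tuner import HyperparameterTuner, ContinuousParameter\n",
     "\n",
     "# Example ranges for two of the tunable hyperparameters\n",
     "hyperparameter_ranges = {'momentum':     ContinuousParameter(0.8, 0.99),\n",
     "                         'weight_decay': ContinuousParameter(0.0001, 0.001)}\n",
     "\n",
     "tuner = HyperparameterTuner(od_model,\n",
     "                            objective_metric_name='validation:mAP',\n",
     "                            objective_type='Maximize',\n",
     "                            hyperparameter_ranges=hyperparameter_ranges,\n",
     "                            max_jobs=6,\n",
     "                            max_parallel_jobs=2)\n",
     "\n",
     "# tuner.fit(inputs=data_channels)"
    ]
   },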
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Submit training job"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have our `Estimator` object, we have set the hyperparameters for this object, and we have our data channels linked with the algorithm. The only remaining thing to do is to train the algorithm using the `fit` method. This will take more than 10 minutes in our example.\n",
    "\n",
     "The training process involves a few steps. First, the instances that we requested while creating the `Estimator` are provisioned and set up with the appropriate libraries. Then, the data from our channels is downloaded into the instance. Once this is done, the actual training begins. The provisioning and data downloading will take time, depending on the size of the data, so it might be a few minutes before our training job logs show up in CloudWatch. The logs will also print out the Mean Average Precision (mAP) on the validation data, among other losses, for every run of the dataset (once per epoch). This metric is a proxy for the accuracy of the model.\n",
    "\n",
    "Once the job has finished, a `Job complete` message will be printed. The trained model artifacts can be found in the S3 bucket that was setup as `output_path` in the estimator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "od_model.fit(inputs=data_channels, logs=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the training job is complete, you can also see the job listed in the `Training jobs` section of your SageMaker console.  Note that the job name is uniquely identified by the name of the algorithm concatenated with the date and time stamp.  You can click on the job to see the details including the hyperparameters, the data channel definitions, and the full path to the resulting model artifacts.  You could even clone the job from the console, and tweak some of the parameters to generate a new training job."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Without having to go to the CloudWatch console, you can see how the job progressed in terms of the key object detection algorithm metric, mean average precision (mAP).  This function below prepares a simple chart of that metric against the epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "%matplotlib inline\n",
    "\n",
    "client = boto3.client('logs')\n",
    "BASE_LOG_NAME = '/aws/sagemaker/TrainingJobs'\n",
    "\n",
    "def plot_object_detection_log(model, title):\n",
    "    logs = client.describe_log_streams(logGroupName=BASE_LOG_NAME, logStreamNamePrefix=model._current_job_name)\n",
    "    cw_log = client.get_log_events(logGroupName=BASE_LOG_NAME, logStreamName=logs['logStreams'][0]['logStreamName'])\n",
    "\n",
    "    mAP_accs=[]\n",
    "    for e in cw_log['events']:\n",
    "        msg = e['message']\n",
    "        if 'validation mAP <score>=' in msg:\n",
    "            num_start = msg.find('(')\n",
    "            num_end = msg.find(')')\n",
    "            mAP = msg[num_start+1:num_end]\n",
    "            mAP_accs.append(float(mAP))\n",
    "\n",
    "    print(title)\n",
    "    print('Maximum mAP: %f ' % max(mAP_accs))\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel('Mean Avg Precision (mAP)')\n",
    "    val_plot,   = ax.plot(range(len(mAP_accs)),   mAP_accs,   label='mAP')\n",
    "    plt.legend(handles=[val_plot])\n",
    "    ax.yaxis.set_ticks(np.arange(0.0, 1.05, 0.1))\n",
    "    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%0.2f'))\n",
    "    plt.show()"
   ]
  },
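  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the parsing above, here is how an illustrative (not verbatim) validation metric line from the training logs would be handled. The score is simply the number between the parentheses:\n",
    "\n",
    "```python\n",
    "# Illustrative log line; the real CloudWatch message format may differ slightly\n",
    "msg = '#quality_metric: epoch=4, validation mAP <score>=(0.421571)'\n",
    "num_start = msg.find('(')\n",
    "num_end = msg.find(')')\n",
    "mAP = float(msg[num_start + 1:num_end])\n",
    "print(mAP)  # 0.421571\n",
    "```"
   ]
  },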
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_object_detection_log(od_model, 'mAP tracking for job: ' + od_model._current_job_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Host the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the training is done, we can deploy the trained model as an Amazon SageMaker real-time hosted endpoint. This lets us make predictions (or inferences) from the model. Note that we don't have to host using the same type of instance that we used to train. Training is a prolonged and compute heavy job with different compute and memory requirements that hosting typically does not. In our case we chose the `ml.p3.2xlarge` instance to train, but we choose to host the model on the less expensive cpu instance, `ml.m4.xlarge`. The endpoint deployment takes several minutes, and can be accomplished with a single line of code calling the `deploy` method.\n",
    "\n",
    "Note that some use cases require large sets of inferences on a predefined body of images.  In those cases, you do not need to make the inferences in real time.  Instead, you could use SageMaker's [batch transform jobs](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-batch.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "object_detector = od_model.deploy(initial_instance_count = 1,\n",
    "                                 instance_type = 'ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the trained model is deployed at an endpoint that is up-and-running, we can use this endpoint for inference.  The results of a call to the inference endpoint are in a format that is similar to the .lst format, with the addition of a confidence score for each detected object. The format of the output can be represented as `[class_index, confidence_score, xmin, ymin, xmax, ymax]`. Typically, we don't visualize low-confidence predictions.\n",
    "\n",
    "We have provided a script to easily visualize the detection outputs. You can visulize the high-confidence preditions with bounding box by filtering out low-confidence detections using the script below:"
   ]
  },
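  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make that output format concrete, here is a minimal sketch (with made-up numbers) of filtering detections in the `[class_index, confidence_score, xmin, ymin, xmax, ymax]` format by a confidence threshold:\n",
    "\n",
    "```python\n",
    "# Each detection: [class_index, confidence_score, xmin, ymin, xmax, ymax]\n",
    "# Coordinates are normalized to [0, 1]; these numbers are made up\n",
    "detections = [\n",
    "    [14.0, 0.93, 0.10, 0.12, 0.60, 0.88],\n",
    "    [14.0, 0.08, 0.55, 0.40, 0.70, 0.65],\n",
    "]\n",
    "thresh = 0.6\n",
    "confident = [d for d in detections if d[1] >= thresh]\n",
    "print(len(confident))  # 1\n",
    "```"
   ]
  },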
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_detection(img_file, dets, classes=[], thresh=0.6):\n",
    "        \"\"\"\n",
    "        visualize detections in one image\n",
    "        Parameters:\n",
    "        ----------\n",
    "        img : numpy.array\n",
    "            image, in bgr format\n",
    "        dets : numpy.array\n",
    "            ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])\n",
    "            each row is one object\n",
    "        classes : tuple or list of str\n",
    "            class names\n",
    "        thresh : float\n",
    "            score threshold\n",
    "        \"\"\"\n",
    "        import random\n",
    "        import matplotlib.pyplot as plt\n",
    "        import matplotlib.image as mpimg\n",
    "\n",
    "        img = mpimg.imread(img_file)\n",
    "        plt.imshow(img)\n",
    "        height = img.shape[0]\n",
    "        width  = img.shape[1]\n",
    "        colors = dict()\n",
    "        num_detections = 0\n",
    "        for det in dets:\n",
    "            (klass, score, x0, y0, x1, y1) = det\n",
    "            if score < thresh:\n",
    "                continue\n",
    "            num_detections += 1\n",
    "            cls_id = int(klass)\n",
    "            if cls_id not in colors:\n",
    "                colors[cls_id] = (random.random(), random.random(), random.random())\n",
    "            xmin = int(x0 * width)\n",
    "            ymin = int(y0 * height)\n",
    "            xmax = int(x1 * width)\n",
    "            ymax = int(y1 * height)\n",
    "            rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False,\n",
    "                                 edgecolor=colors[cls_id], linewidth=3.5)\n",
    "            plt.gca().add_patch(rect)\n",
    "            class_name = str(cls_id)\n",
    "            if classes and len(classes) > cls_id:\n",
    "                class_name = classes[cls_id]\n",
    "            print('{},{}'.format(class_name,score))\n",
    "            plt.gca().text(xmin, ymin - 2,\n",
    "                            '{:s} {:.3f}'.format(class_name, score),\n",
    "                            bbox=dict(facecolor=colors[cls_id], alpha=0.5),\n",
    "                                    fontsize=12, color='white')\n",
    "\n",
    "        print('Number of detections: ' + str(num_detections))\n",
    "        plt.show()"
   ]
  },
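  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The coordinate scaling used in `visualize_detection` can be checked in isolation: the model returns box corners normalized to `[0, 1]`, which are multiplied by the image width and height (the numbers below are made up):\n",
    "\n",
    "```python\n",
    "width, height = 500, 400                 # example image size\n",
    "x0, y0, x1, y1 = 0.2, 0.25, 0.6, 0.75    # normalized corners from a detection\n",
    "xmin, ymin = int(x0 * width), int(y0 * height)\n",
    "xmax, ymax = int(x1 * width), int(y1 * height)\n",
    "print(xmin, ymin, xmax, ymax)  # 100 100 300 300\n",
    "```"
   ]
  },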
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we use our endpoint to try to detect objects within an image. Since the image is a jpeg, we use the appropriate content_type to run the prediction. The endpoint returns a JSON object that we can simply load and peek into. We have packaged the prediction code into a function to make it easier to test other images.  Note that we are defaulting the confidence threshold to 30% in our example, as a couple of the birds in our sample images were not being detected as clearly.  Defining an appropriate threshold is entirely dependent on your use case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "OBJECT_CATEGORIES = classes_df['class_id'].values.tolist()\n",
    "\n",
    "def show_bird_prediction(filename, ep, thresh=0.40):\n",
    "    b = ''\n",
    "    with open(filename, 'rb') as image:\n",
    "        f = image.read()\n",
    "        b = bytearray(f)\n",
    "    endpoint_response = runtime.invoke_endpoint(EndpointName=ep,\n",
    "                                           ContentType='image/jpeg',\n",
    "                                           Body=b)\n",
    "    results = endpoint_response['Body'].read()\n",
    "    detections = json.loads(results)\n",
    "    visualize_detection(filename, detections['prediction'], OBJECT_CATEGORIES, thresh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we download images that the algorithm has not yet seen."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget -q -O multi-goldfinch-1.jpg https://t3.ftcdn.net/jpg/01/44/64/36/500_F_144643697_GJRUBtGc55KYSMpyg1Kucb9yJzvMQooW.jpg\n",
    "!wget -q -O northern-flicker-1.jpg https://upload.wikimedia.org/wikipedia/commons/5/5c/Northern_Flicker_%28Red-shafted%29.jpg\n",
    "!wget -q -O northern-cardinal-1.jpg https://cdn.pixabay.com/photo/2013/03/19/04/42/bird-94957_960_720.jpg\n",
    "!wget -q -O blue-jay-1.jpg https://cdn12.picryl.com/photo/2016/12/31/blue-jay-bird-feather-animals-b8ee04-1024.jpg\n",
    "!wget -q -O hummingbird-1.jpg http://res.freestockphotos.biz/pictures/17/17875-hummingbird-close-up-pv.jpg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_model():\n",
    "    show_bird_prediction('hummingbird-1.jpg', object_detector.endpoint)\n",
    "    show_bird_prediction('blue-jay-1.jpg', object_detector.endpoint)\n",
    "    show_bird_prediction('multi-goldfinch-1.jpg', object_detector.endpoint)\n",
    "    show_bird_prediction('northern-flicker-1.jpg', object_detector.endpoint)\n",
    "    show_bird_prediction('northern-cardinal-1.jpg', object_detector.endpoint)\n",
    "\n",
    "test_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Clean up\n",
    "Here we delete the SageMaker endpoint, as we will no longer be performing any inferences.  This is an important step, as your account is billed for the amount of time an endpoint is running, even when it is idle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker.Session().delete_endpoint(object_detector.endpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Improve the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Function to Flip the Images Horizontally (on the X Axis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "def flip_images():\n",
    "    print('Flipping images...')\n",
    "    \n",
    "    SIZE_COLS  = ['idx','width','height']\n",
    "    IMAGE_COLS = ['image_pretty_name','image_file_name']\n",
    "    LABEL_COLS = ['image_pretty_name','class_id']\n",
    "    BBOX_COLS  = ['image_pretty_name', 'x_abs', 'y_abs', 'bbox_width', 'bbox_height']\n",
    "    SPLIT_COLS = ['image_pretty_name', 'is_training_image']\n",
    "\n",
    "    images_df = pd.read_csv(BASE_DIR + 'images.txt',\n",
    "                            sep=' ', names=IMAGE_COLS, header=None)\n",
    "    image_class_labels_df = pd.read_csv(BASE_DIR + 'image_class_labels.txt',\n",
    "                            sep=' ', names=LABEL_COLS, header=None)\n",
    "    bboxes_df = pd.read_csv(BASE_DIR + 'bounding_boxes.txt',\n",
    "                            sep=' ', names=BBOX_COLS, header=None)\n",
    "    split_df = pd.read_csv(BASE_DIR + 'train_test_split.txt',\n",
    "                            sep=' ', names=SPLIT_COLS, header=None)\n",
    "\n",
    "    NUM_ORIGINAL_IMAGES = images_df.shape[0]\n",
    "\n",
    "    rows_list        = []\n",
    "    bbox_rows_list   = []\n",
    "    size_rows_list   = []\n",
    "    label_rows_list  = []\n",
    "    split_rows_list  = []\n",
    "\n",
    "    idx = 0\n",
    "\n",
    "    full_df = images_df.copy()\n",
    "    full_df.reset_index(inplace=True)\n",
    "    full_df = pd.merge(full_df, image_class_labels_df, on='image_pretty_name')\n",
    "    full_df = pd.merge(full_df, bboxes_df, on='image_pretty_name')\n",
    "    full_df = pd.merge(full_df, split_df, on='image_pretty_name')\n",
    "    full_df.sort_values(by=['index'], inplace=True)\n",
    "\n",
    "    if SAMPLE_ONLY:\n",
    "        # grab a small subset of species for testing\n",
    "        criteria = full_df['class_id'].isin(CLASSES)\n",
    "        full_df = full_df[criteria]\n",
    "\n",
    "    for rel_image_fn in full_df['image_file_name']:\n",
    "        idx += 1\n",
    "        full_img_content = full_df[(full_df.image_file_name == rel_image_fn)]\n",
    "\n",
    "        class_id = full_img_content.iloc[0].class_id\n",
    "\n",
    "        img = Image.open(IMAGES_DIR + rel_image_fn)\n",
    "\n",
    "        width, height = img.size\n",
    "\n",
    "        new_idx = idx + NUM_ORIGINAL_IMAGES\n",
    "\n",
    "        flip_core_file_name = rel_image_fn[:-4] + '_flip.jpg'\n",
    "        flip_full_file_name = IMAGES_DIR + flip_core_file_name\n",
    "\n",
    "        img_flip = img.transpose(Image.FLIP_LEFT_RIGHT)\n",
    "        img_flip.save(flip_full_file_name)\n",
    "\n",
    "        # append a new image\n",
    "        dict = {'image_pretty_name': new_idx, 'image_file_name': flip_core_file_name}\n",
    "        rows_list.append(dict)\n",
    "\n",
    "        # append a new split, use same flag for flipped image from original image\n",
    "        is_training_image = full_img_content.iloc[0].is_training_image\n",
    "        split_dict = {'image_pretty_name': new_idx, 'is_training_image': is_training_image}\n",
    "        split_rows_list.append(split_dict)\n",
    "\n",
    "        # append a new image class label\n",
    "        label_dict = {'image_pretty_name': new_idx, 'class_id': class_id}\n",
    "        label_rows_list.append(label_dict)\n",
    "\n",
    "        # add a size row for the original and the flipped image, same height and width\n",
    "        size_dict = {'idx': idx, 'width': width, 'height': height}\n",
    "        size_rows_list.append(size_dict)\n",
    "\n",
    "        size_dict = {'idx': new_idx, 'width': width, 'height': height}\n",
    "        size_rows_list.append(size_dict)\n",
    "\n",
    "        # append bounding box for flipped image\n",
    "\n",
    "        x_abs = full_img_content.iloc[0].x_abs\n",
    "        y_abs = full_img_content.iloc[0].y_abs\n",
    "        bbox_width  = full_img_content.iloc[0].bbox_width\n",
    "        bbox_height = full_img_content.iloc[0].bbox_height\n",
    "        flipped_x_abs = width - bbox_width - x_abs\n",
    "\n",
    "        bbox_dict = {'image_pretty_name': new_idx, 'x_abs': flipped_x_abs,\n",
    "                    'y_abs': y_abs, 'bbox_width': bbox_width, 'bbox_height': bbox_height}\n",
    "        bbox_rows_list.append(bbox_dict)\n",
    "\n",
    "    print('Done looping through original images')\n",
    "\n",
    "    images_df = images_df.append(rows_list)\n",
    "    images_df[IMAGE_COLS].to_csv(IMAGE_FILE, sep=' ', index=False, header=None)\n",
    "    bboxes_df = bboxes_df.append(bbox_rows_list)\n",
    "    bboxes_df[BBOX_COLS].to_csv(BBOX_FILE, sep=' ', index=False, header=None)\n",
    "    split_df = split_df.append(split_rows_list)\n",
    "    split_df[SPLIT_COLS].to_csv(SPLIT_FILE, sep=' ', index=False, header=None)\n",
    "    sizes_df = pd.DataFrame(size_rows_list)\n",
    "    sizes_df[SIZE_COLS].to_csv(SIZE_FILE, sep=' ', index=False, header=None)\n",
    "    image_class_labels_df = image_class_labels_df.append(label_rows_list)\n",
    "    image_class_labels_df[LABEL_COLS].to_csv(LABEL_FILE, sep=' ', index=False, header=None)\n",
    "\n",
    "    print('Done saving metadata in text files')"
   ]
  },
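  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bounding box arithmetic above is easy to verify with a toy example: a horizontal flip keeps the box width and vertical position, and the new left edge is the mirror image of the old right edge (the numbers below are made up):\n",
    "\n",
    "```python\n",
    "width = 500                    # image width\n",
    "x_abs, bbox_width = 60, 120    # original left edge and box width\n",
    "flipped_x_abs = width - bbox_width - x_abs\n",
    "print(flipped_x_abs)  # 320\n",
    "# old right edge = x_abs + bbox_width = 180; mirrored: 500 - 180 = 320\n",
    "```"
   ]
  },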
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Re-train the model with the expanded dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "BBOX_FILE  = BASE_DIR + 'bounding_boxes_with_flip.txt'\n",
    "IMAGE_FILE = BASE_DIR + 'images_with_flip.txt'\n",
    "LABEL_FILE = BASE_DIR + 'image_class_labels_with_flip.txt'\n",
    "SIZE_FILE  = BASE_DIR + 'sizes_with_flip.txt'\n",
    "SPLIT_FILE = BASE_DIR + 'train_test_split_with_flip.txt'\n",
    "\n",
    "# add a set of flipped images\n",
    "flip_images()\n",
    "\n",
    "# show the new full set of images for a species\n",
    "show_species('017.Cardinal')\n",
    "\n",
    "# create new sizes file\n",
    "gen_image_size_file()\n",
    "\n",
    "# re-create and re-deploy the RecordIO files with the updated set of images\n",
    "train_df, val_df = gen_list_files()\n",
    "!python tools/im2rec.py --resize $RESIZE_SIZE --pack-label birds_ssd_sample $BASE_DIR/images/\n",
    "sess.upload_data(path='birds_ssd_sample_train.rec', bucket=bucket, key_prefix=train_channel)\n",
    "sess.upload_data(path='birds_ssd_sample_val.rec', bucket=bucket, key_prefix=validation_channel)\n",
    "\n",
    "# account for the new number of training images\n",
    "set_hyperparameters(100, '33,67')\n",
    "\n",
    "# re-train\n",
    "od_model.fit(inputs=data_channels, logs=True)\n",
    "\n",
    "# check out the new accuracy\n",
    "plot_object_detection_log(od_model, 'mAP tracking for job: ' + od_model._current_job_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Re-deploy and test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# host the updated model\n",
    "object_detector = od_model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')\n",
    "\n",
    "# test the new model\n",
    "test_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Final cleanup\n",
    "Here we delete the SageMaker endpoint, as we will no longer be performing any inferences.  This is an important step, as your account is billed for the amount of time an endpoint is running, even when it is idle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# delete the new endpoint\n",
    "sagemaker.Session().delete_endpoint(object_detector.endpoint)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
